#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# ovpn.py  -  Utilities for connecting via OpenVPN and setup of keys/certs
#
# Copyright (C) 2015 Jan Jockusch <jan.jockusch@perfact.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import os
import subprocess

from .generic import to_string

# Extra ssh options: only report errors, so that informational notices like
#  Warning: Permanently added '127.0.0.1' (ECDSA) to the list of known hosts.
# do not get mixed into output we later parse.
sshopts = ' -o LogLevel=ERROR '

# Note regarding shell interpretation of ssh commands:
# The local ssh invocation is built as an argument list and run by subprocess
# without a shell (shell=False is subprocess' default). The remote command,
# i.e. the last argument handed to ssh, is however always run through a shell
# on the remote side. This cannot be avoided: even if several arguments are
# given in the command position, ssh joins them with spaces and hands the
# resulting string to a shell.
sshcmd = ['ssh'] + sshopts.split()


def ovpn_kill_conn(ovpnconn_id, proxy="vpnnode1", mode="tun"):
    '''Kill one or more existing VPN connections.
    By default a "tun" connection is terminated.

    :ovpnconn_id - a single ID or a list of IDs; each must consist of
        digits only once converted with str().
    :proxy - ssh destination running the OpenVPN server.
    :mode - "tun" or "tap".

    Returns the stdout of the remote command.
    Raises AssertionError on invalid input or when the remote command fails.
    '''
    # Explicit raises instead of assert statements: asserts are stripped
    # under "python -O", and mode is interpolated into a remote shell script.
    if mode not in ('tun', 'tap'):
        raise AssertionError('Invalid mode')
    if not isinstance(ovpnconn_id, list):
        ovpnconn_id = [ovpnconn_id]
    ids = [str(elem) for elem in ovpnconn_id]
    # Make sure these are valid integers
    for elem in ids:
        if not elem.isdigit():
            raise AssertionError("Invalid ID")

    # One ID per line; the remote loop consumes them from stdin.
    idlist = ''.join(elem + '\n' for elem in ids)

    # see note regarding shell interpretation of ssh commands above
    proc = subprocess.Popen(
        sshcmd + [
            proxy,
            """
            while read ovpnid; do
                /home/mpaproxy/ovpncommand kill "$ovpnid" {mode} || exit 1
            done
            """.format(mode=mode)
        ],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    out, _ = proc.communicate(idlist)
    if proc.returncode != 0:
        raise AssertionError("Failed to execute SSH command.")
    return out


def ovpn_status(proxy="vpnnode1", mode="tun", version=3):
    '''Retrieve the OpenVPN server process' internal state table showing
    connected clients, routing information and statistics.
    By default the status version 3 format is chosen.

    :proxy - ssh destination running the OpenVPN server.
    :mode - "tun" or "tap".
    :version - status format; only 3 is supported.

    Returns the raw status output.
    Raises NotImplementedError for an unsupported version, AssertionError
    for an invalid mode or when the output does not contain 'OpenVPN', and
    subprocess.CalledProcessError when the remote command fails.
    '''
    if version not in [3]:
        raise NotImplementedError(
            'Only versions 3 is supported'
        )
    # Explicit raise instead of assert: asserts are stripped under
    # "python -O", and mode is interpolated into the remote shell command.
    if mode not in ('tun', 'tap'):
        raise AssertionError('Invalid mode')
    cmd = '/home/mpaproxy/ovpncommand status {} {}'.format(
        version, mode
    )
    # see note regarding shell interpretation of ssh commands above
    out = subprocess.check_output(sshcmd + [proxy, cmd],
                                  universal_newlines=True)
    if 'OpenVPN' not in out:
        raise AssertionError(
            'Unexpected result: %s while calling with: proxy=%s, '
            'mode=%s, version=%s' % (out, proxy, mode, version)
        )
    return out


def ovpn_status_parse(status, version=3):
    '''Parse the output of the openvpn server management console status
    command and get a list of connected clients.

    :status - String containing only ASCII characters (output of
        'ovpn_status'); it is passed through to_string() first.
    :version - Integer, currently only version=3 is supported. This refers to
        the setting '--status-version 3' of the OpenVPN server process
        (see man openvpn) and the 'status 3' command on the OpenVPN server
        management console.

    Returns a dict {'clients': [...]} with the integer client IDs found in
    CLIENT_LIST lines, in order of first appearance and without duplicates.
    Raises ValueError for an empty status and NotImplementedError for any
    version other than 3.
    '''
    status = to_string(status)

    if not status:
        raise ValueError(
            'The status parameter can not be empty.'
        )

    if version != 3:
        raise NotImplementedError(
            'Only version 3 format is currently supported'
        )

    clients = []
    # Deduplicate on the parsed integer, which is what clients holds.
    seen = set()
    # version 3 is a <LF> and <TAB> separated table
    for line in status.split('\n'):
        # we are looking for lines starting with CLIENT_LIST, e.g.
        # "CLIENT_LIST<TAB>1099925<TAB>80.146.225.220:49260<TAB>..."
        if not line.lower().startswith('client_list'):
            continue
        parts = line.split('\t')
        if len(parts) < 2:
            continue
        # we are only interested in the second field (1099925 above);
        # strip() tolerates CRLF line endings and stray whitespace
        cid = parts[1].strip()
        if not cid.isdigit():
            continue
        try:
            number = int(cid)
        except ValueError:
            # isdigit() also accepts characters such as superscripts,
            # which int() cannot parse
            continue
        if number not in seen:
            seen.add(number)
            clients.append(number)

    return {'clients': clients}


def ovpn_create_conf(cn, conf, proxy="vpnnode1"):
    '''Write a configuration file for the connecting OpenVPN client. This
    could include routing information etc.

    :cn - common name of the client; must consist of digits only (at most
        100 characters). It becomes the file name below
        /etc/openvpn/client-config/ on the proxy.
    :conf - configuration text (at most 16384 characters), sent via stdin.
    :proxy - ssh destination running the OpenVPN server.

    Raises AssertionError on invalid input or when the remote write fails.
    '''
    cn = str(cn)
    # Explicit raises instead of assert statements: asserts are stripped
    # under "python -O", and cn is interpolated into a remote shell command.
    if not cn.isdigit():
        raise AssertionError("Invalid common name (must be all digits)")
    if len(cn) > 100:
        raise AssertionError("Common name much too long")
    if len(conf) > 16384:
        raise AssertionError("Configuration string too long (> 16384)")

    # see note regarding shell interpretation of ssh commands above
    proc = subprocess.Popen(
        sshcmd + [proxy, "cat > /etc/openvpn/client-config/{}".format(cn)],
        stdin=subprocess.PIPE,
        universal_newlines=True,
    )
    proc.communicate(conf)
    if proc.returncode != 0:
        raise AssertionError("Failed to create config")


def ovpn_delete_conf(cn, proxy="vpnnode1"):
    '''Delete the client configuration file(s) on the VPN proxy.

    :cn - one value accepted by int(), or a list of such values; each is
        normalized with int() and used as a file name below
        /etc/openvpn/client-config/.
    :proxy - ssh destination running the OpenVPN server.

    Files that do not exist are ignored (rm -f). A failing remote command
    raises subprocess.CalledProcessError.
    '''
    names = cn if isinstance(cn, list) else [cn]
    names = [str(int(name)) for name in names]
    for name in names:
        assert len(name) <= 100, "Common name much too long"

    config_dir = '/etc/openvpn/client-config/'
    targets = ' '.join(config_dir + name for name in names)

    # see note regarding shell interpretation of ssh commands above
    # Probe the proxy first: if it cannot be reached or the directory is not
    # visible, stat fails and check_call raises before anything is removed.
    subprocess.check_call(sshcmd + [proxy, 'stat ' + config_dir])

    subprocess.check_call(sshcmd + [proxy, 'rm -f ' + targets])


def ovpn_list_conf(proxy="vpnnode1"):
    '''List the client configurations present on the given VPN proxy.

    Returns the whitespace-separated entries printed by a remote "ls" of
    /etc/openvpn/client-config, as a list of strings. A failing remote
    command raises subprocess.CalledProcessError.
    '''
    remote_cmd = 'ls /etc/openvpn/client-config'
    # see note regarding shell interpretation of ssh commands above
    listing = subprocess.check_output(
        sshcmd + [proxy, remote_cmd],
        universal_newlines=True,
    )
    entries = listing.split()
    return entries


def ovpn_initialize(subj, expire_days=None, passphrase=None, appca_id=1,
                    ovpn_path='/home/zope/ovpn', rsabits=4096, dhbits=2048,
                    ca_path_fmt="/home/zope/CA/%d"):
    '''Generate a RSA private key and X.509 certificate for the OpenVPN
    server process.

    Generate random parameters for Diffie-Hellman key exchange
    (long prime & generator) - these can be made public.

    Generate a pre shared key (PSK) for additional TLS authentication using
    HMAC and thus preventing denial of service attacks. ('man openvpn =>
    --tls-auth') This PSK has to be available on server & client side.

    The following files are (re)created in ovpn_path: public.pem (signed
    certificate), private.pem (private key), cacert.pem (copy of the
    signing CA's certificate), dhparam.pem and tlsauth.key.

    :subj - certificate subject, passed to cert_makecsr
    :expire_days - passed to ca_signreq as "days"
    :passphrase - passed to cert_makecsr and ca_signreq
    :appca_id - ID of the signing CA; its directory is ca_path_fmt % appca_id
    :ovpn_path - output directory (created via "mkdir -p -m 700")
    :rsabits - key size passed to cert_makekey
    :dhbits - size in bits of the Diffie-Hellman parameters
    :ca_path_fmt - format string mapping appca_id to the CA directory

    Returns True. Raises subprocess.CalledProcessError if one of the
    external commands (mkdir, cp, openssl, openvpn) fails.
    '''
    from .cert import cert_makekey, cert_makecsr, ca_signreq

    # Generate RSA key and X.509 certificate signed by the appropriate
    # appca_id
    ca_path = ca_path_fmt % int(appca_id)

    key = cert_makekey(rsabits=rsabits)
    csr = cert_makecsr(key=key, subj=subj, request_type='server',
                       passphrase=passphrase)
    cert = ca_signreq(csr, passphrase=passphrase, days=expire_days,
                      appca_id=appca_id, ca_path_fmt=ca_path_fmt)

    # check_call: a failure to create the directory must not be ignored,
    # otherwise it only surfaces later as a confusing file error.
    subprocess.check_call(['mkdir', '-p', '-m', '700', ovpn_path])

    def target(name):
        '''Return the full path of the file "name" inside ovpn_path.'''
        return os.path.join(ovpn_path, name)

    # Remove old files so that the permissions of the new files
    # will be set properly. lexists() also catches dangling symlinks,
    # which O_CREAT below would otherwise follow.
    for name in [
        'public.pem',
        'private.pem',
        'cacert.pem',
        'dhparam.pem',
        'tlsauth.key'
    ]:
        full_path = target(name)
        if os.path.lexists(full_path):
            os.remove(full_path)

    # Flags for file descriptor
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC

    def write_restricted(name, content):
        '''Create file "name" with mode 0o640 (minus umask), write content.'''
        fd = os.open(target(name), flags, 0o640)  # Restricted permissions!
        try:
            fh = os.fdopen(fd, 'w')
        except Exception:
            # fdopen did not take ownership of the descriptor; close it.
            os.close(fd)
            raise
        with fh:
            fh.write(content)

    write_restricted('public.pem', cert)
    write_restricted('private.pem', key)

    # Method for setting the umask for the processes we create
    def preexec():
        # Restrict permissions of newly created files:
        # Remove "write" for "group" and everything for "other"
        os.umask(0o027)

    # Copy the CA certificate to our current path
    # Symlink in the /etc/openvpn directory points to it!
    subprocess.check_call(
        ['cp', os.path.join(ca_path, 'certs', 'cacert.pem'),
         target('cacert.pem')],
        preexec_fn=preexec
    )

    # Generate Diffie-Hellman parameters (prime & generator)
    # Symlink in the /etc/openvpn directory points to it!
    subprocess.check_call(
        ['openssl', 'dhparam', '-out', target('dhparam.pem'), str(dhbits)],
        preexec_fn=preexec
    )

    # Generate TLS authentication key (PSK) for HMAC
    # Symlink in the /etc/openvpn directory points to it!
    subprocess.check_call(
        ['openvpn', '--genkey', '--secret', target('tlsauth.key')],
        preexec_fn=preexec
    )

    return True


def ovpn_gettlsauth(ovpn_path='/home/zope/ovpn'):
    '''Return the TLS authentication key (PSK) stored in ovpn_path.

    Reads "tlsauth.key" from ovpn_path and returns its content starting at
    the "-----BEGIN OpenVPN Static key V1-----" line, i.e. without any
    leading comment lines.

    Raises ValueError if that marker is not present in the file; errors
    from opening/reading the file propagate unchanged.
    '''
    marker = '-----BEGIN OpenVPN Static key V1-----'
    path = os.path.join(ovpn_path, 'tlsauth.key')
    with open(path, 'r') as fh:
        content = fh.read()
    index = content.find(marker)
    if index < 0:
        # str.find() returns -1 when absent; slicing with it would silently
        # return only the last character of the file.
        raise ValueError('No OpenVPN static key found in %s' % (path,))
    return content[index:]


def assert_safe_for_bash(value):
    '''Check the input for a possible Shell (Bash) exploit.

    The given input is compared with a list of bad/forbidden characters:
    command separators and chaining (; & | newline), expansion and
    substitution ($ ` ( {), redirection (< >), history expansion (!),
    escaping (\\), globbing (* ? [) as well as ':' and '´'.
    If one is found, an AssertionError is raised; otherwise True is
    returned. Byte strings are decoded (UTF-8, undecodable bytes replaced)
    before checking.
    '''
    if isinstance(value, bytes) and not isinstance(value, str):
        # Under Python 3, iterating bytes yields ints which never match the
        # characters below, so bytes would always pass unchecked.
        value = value.decode('utf-8', 'replace')
    # shell security check; the backtick '`' (command substitution) is
    # listed separately from the look-alike acute accent '´'.
    bad = [';',
           '\\',
           '&',
           '(',
           '´',
           '`',
           '|',
           '<',
           '>',
           '!',
           '$',
           ':',
           '*',
           '?',
           '[',
           '{',
           '\n',
           ]
    # Explicit raise instead of an assert statement: asserts are stripped
    # under "python -O", which would silently disable this security check.
    if set(value).intersection(bad):
        raise AssertionError(
            'Forbidden characters for bash found in: %s' % (value,))
    return True


if __name__ == '__main__':
    # No self-test lives here any more; point the user to the pytest suite.
    _notice = ("The Tests have been migrated to pytest "
               "(/perfact/tests/test_ovpn.py)")
    print(_notice)
